RWKV v6: RWKV_WKV op CUDA implementation #9454
Conversation
force-pushed from 7dd075a to 8ec53bf
force-pushed from 8ec53bf to 19f2a61
Hi, Molly.
force-pushed from 19f2a61 to 7c39f2d
Yes. However, I don't think it's that urgent; it can also be done after RWKV v7 is released, in the initial RWKV v7 support PR.
Hi! @ggerganov
* ggml: CUDA unary op EXP
* ggml: rwkv_wkv op CUDA impl
Added the RWKV_WKV CUDA impl and a test_case in test-backend-ops.cpp.
Also added the unary op EXP for CUDA, so that the RWKV v6 graph gets split less when running on a GPU.
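For reference, an elementwise EXP kernel boils down to something like the following minimal sketch. This is illustrative only, not the PR's actual ggml-cuda code; the function name and launch parameters are assumptions:

```cuda
#include <cuda_runtime.h>

// Minimal sketch of an elementwise EXP kernel over a flat f32 buffer:
// one thread per element. Illustrative; not the PR's actual code.
__global__ void exp_f32(const float * x, float * dst, const int n) {
    const int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    dst[i] = expf(x[i]);
}

// Example launch, assuming a block size of 256:
// exp_f32<<<(n + 255) / 256, 256>>>(x, dst, n);
```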
The kernel is adapted from https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/cuda/rwkv6.cu, with added support for batched inference.
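To make the structure concrete, here is a minimal sketch of the recurrence such a WKV kernel computes. It assumes f32 tensors, a head size of 64, one block per (batch, head) pair, and one thread per value channel; all names and layouts are illustrative assumptions, and it omits any vectorization the real kernel may use:

```cuda
#include <cuda_runtime.h>

#define HEAD_SIZE 64  // assumption: RWKV v6 head size, also the block size

// Sketch of the per-(batch, head) WKV recurrence. For each token t and
// value channel i (this thread):
//   y[t][i]  = sum_j r[t][j] * (tf[j] * k[t][j] * v[t][i] + S[j][i])
//   S[j][i] <- S[j][i] * td[t][j] + k[t][j] * v[t][i]
__global__ void wkv6_sketch(const int n_tokens, const int n_heads,
                            const float * k, const float * v, const float * r,
                            const float * tf, const float * td,
                            const float * s_in, float * dst, float * s_out) {
    const int i     = threadIdx.x;           // value channel within the head
    const int head  = blockIdx.x % n_heads;
    const int batch = blockIdx.x / n_heads;
    const int C     = n_heads * HEAD_SIZE;   // channels per token

    // This thread owns one column of the per-head state matrix.
    float S[HEAD_SIZE];
    for (int j = 0; j < HEAD_SIZE; j++) {
        S[j] = s_in[(batch * n_heads + head) * HEAD_SIZE * HEAD_SIZE + j * HEAD_SIZE + i];
    }

    // Stage per-token vectors in shared memory so every thread can read
    // all channels of the head.
    __shared__ float sk[HEAD_SIZE], sr[HEAD_SIZE], std_[HEAD_SIZE], stf[HEAD_SIZE];
    stf[i] = tf[head * HEAD_SIZE + i];       // time_first is constant over tokens

    for (int t = 0; t < n_tokens; t++) {
        const int idx = (batch * n_tokens + t) * C + head * HEAD_SIZE + i;
        __syncthreads();
        sk[i]   = k[idx];
        sr[i]   = r[idx];
        std_[i] = td[idx];
        __syncthreads();

        const float vv = v[idx];
        float y = 0.0f;
        for (int j = 0; j < HEAD_SIZE; j++) {
            const float kv = sk[j] * vv;
            y   += sr[j] * (stf[j] * kv + S[j]);
            S[j] = S[j] * std_[j] + kv;
        }
        dst[idx] = y;
    }

    // Write the updated state back so the next chunk continues from it.
    for (int j = 0; j < HEAD_SIZE; j++) {
        s_out[(batch * n_heads + head) * HEAD_SIZE * HEAD_SIZE + j * HEAD_SIZE + i] = S[j];
    }
}

// Example launch for B sequences of n_tokens each:
// wkv6_sketch<<<B * n_heads, HEAD_SIZE>>>(n_tokens, n_heads,
//                                         k, v, r, tf, td, s_in, dst, s_out);
```

The key point is that the token loop is inherently sequential (the state carries across tokens), so parallelism comes from batches, heads, and channels; batched inference simply extends the grid to B * n_heads blocks.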
Going to add some speed and other test results tomorrow.

Prompt:
Here's the speed comparison between the original and the PR version.
The test was done on my weird 12900HK ES + RTX 4090 PC, which is relatively CPU-bound. All tests use FP16 with all layers offloaded to the GPU. Prompt length = 107, generation length = 1000.
Here's the perplexity comparison between the original and the PR version, tested on wikitext-2 using FP16 with all layers offloaded to the GPU.
test-backend-ops perf tests:
TODO: